
Speed up your code with tf.function

Google TensorFlow 2021-07-27

In TensorFlow 2.0, eager execution is enabled by default. This gives you a very intuitive and flexible interface, where running one-off operations is much easier and faster, but it can come at the expense of performance and deployability.
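For instance, under eager execution an op runs immediately and returns a concrete value you can inspect from Python. A trivial sketch (any TF 2.x install behaves the same):

import tensorflow as tf

x = tf.constant([[1.0, 2.0]])
y = x * 2          # executes immediately; no graph or session required
print(y.numpy())   # [[2. 4.]]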


To get peak performance and to make your model deployable anywhere, we provide tf.function as the tool you can use to make graphs out of your programs.

from __future__ import absolute_import, division, print_function, unicode_literals

!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf

# A function is like an op

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

<tf.Tensor: id=16, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>


The tf.function you define is like a core TensorFlow operation: you can execute it eagerly, you can use it in graphs, it has gradients, and so on.

# Functions have gradients

@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)

<tf.Tensor: id=44, shape=(), dtype=float32, numpy=1.0>
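Since add(v, 1.0) computes v + 1.0, its derivative with respect to v is 1.0, which is exactly the scalar the tape returns above.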

# You can use functions inside functions

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: id=74, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>



Polymorphism

tf.function tries to be as generic as a Python function: you can call Python functions with all sorts of signatures, and Python will usually do something reasonable. tf.function handles this kind of polymorphism for you, even though the underlying TensorFlow graphs it generates are specific to the particular types in their signatures.


You can call the function with arguments of different types to see what happens.

# Functions are polymorphic

@tf.function
def add(a):
  return a + a

print("add 1", add(1))
print("add 1.1", add(1.1))
print("add string tensor", add(tf.constant("a")))
c = add.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
c(a=tf.constant("a"))  # aa

add 1 tf.Tensor(2, shape=(), dtype=int32)
add 1.1 tf.Tensor(2.2, shape=(), dtype=float32)
add string tensor tf.Tensor(b'aa', shape=(), dtype=string)

<tf.Tensor: id=104, shape=(), dtype=string, numpy=b'aa'>
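Note that each distinct input signature triggers a fresh trace, and the Python body runs only at trace time. A minimal sketch, assuming TF 2.x, that makes the retracing visible with a plain Python print:

@tf.function
def double(a):
  print("Tracing with", a)  # Python side effect: runs only while tracing
  return a + a

double(tf.constant(1))    # traces for int32
double(tf.constant(2))    # reuses the int32 graph; prints nothing
double(tf.constant(1.0))  # new dtype, so it traces again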

# Functions can be faster than eager code, for graphs with many small ops

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)

input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))

Eager conv: 0.20972437999444082
Function conv: 0.21063927400973625
Note how there's not much difference in performance for convolutions
eager lstm: 0.033881522991578095
function lstm: 0.005326402999344282
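The single convolution is dominated by the kernel computation itself, so running it from a graph changes little; the LSTM cell, by contrast, is composed of many small ops whose per-op Python dispatch overhead the graph eliminates, which is where the roughly 6x speedup above comes from.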



State in tf.function

A very appealing property of tf.function as a programming model, compared with a general dataflow graph, is that a function can give the runtime more information about the intended behavior of the code.


For example, when writing code with multiple reads and writes to the same variables, a dataflow graph might not naturally encode the originally intended order of operations. In tf.function, however, because we are converting code traced from Python, we know the intended execution order.


This means there is no need to add manual control dependencies; tf.function is smart enough to add the minimal set of necessary and sufficient control dependencies for your code to run correctly.

# Automatic control dependencies

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0

<tf.Tensor: id=1610, shape=(), dtype=float32, numpy=10.0>
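Tracing through in program order: a is first assigned y * b = 2.0 * 2.0 = 4.0; b then becomes b + x * a = 2.0 + 1.0 * 4.0 = 6.0; and the function returns a + b = 10.0. The automatically added control dependencies are what guarantee both writes complete, in that order, before the final read.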



Variables

We can use the same idea of leveraging the code's intended execution order to make variable creation and use very easy in tf.function. There is one very important caveat, though: with variables it is possible to write code that behaves differently when called eagerly multiple times than when its output tensor is evaluated multiple times.


Here is a simple example:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

f(1.)  # Note: this breaks and will raise an exception


If you ran this eagerly you would always get "2" as the answer, but if you repeatedly evaluated the Tensor obtained from f(1.) in a graph context you would get ever-increasing numbers.


That is why tf.function does not allow you to write code like this.

# But non-ambiguous code is fine

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

f(1.0)  # 2.0
f(2.0)  # 4.0

<tf.Tensor: id=1635, shape=(), dtype=float32, numpy=4.0>

# You can also create variables inside a tf.function, as long as we can prove
# that the variable is created only the first time the function is executed.

class C: pass
obj = C(); obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

g(1.0)  # 2.0
g(2.0)  # 4.0

<tf.Tensor: id=1689, shape=(), dtype=float32, numpy=4.0>

# Variable initializers can depend on function arguments and on the values
# of other variables. We can figure out the right initialization order using
# the same method we use to generate control dependencies.

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

fn(tf.constant(1.0))
fn(tf.constant(3.0))

WARNING: Logging before flag parsing goes to stderr.

W0418 23:39:09.505958 139706314610432 tf_logging.py:161] Entity <method-wrapper '__call__' of weakref object at 0x7f0f787bd1d8> could not be transformed and will be staged without change. Error details can be found in the logs when running with the env variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the AutoGraph team. Cause: Object conversion is not yet supported. If you are trying to convert code that uses an existing object, try including the creation of that object in the conversion. For example, instead of converting the method of a class, try converting the entire class instead. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/README.md#using-the-functional-api for more information.

(The same warning is logged for a second weakref object and echoed again on stderr; the duplicate lines are omitted.)

<tf.Tensor: id=1796, shape=(), dtype=float32, numpy=36.0>
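To check the numbers: the first call (x = 1.0) initializes state[0] to 2.0 * 1.0 = 2.0 and state[1] to 2.0 * 3.0 = 6.0; the second call then returns state[0] * x * state[1] = 2.0 * 3.0 * 6.0 = 36.0, matching the output above.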



Control flow and AutoGraph

While tf.cond and tf.while_loop continue to work with tf.function, we provide a better alternative based on lightweight compilation of your Python code.


The AutoGraph library is fully integrated with tf.function, and it will rewrite conditionals and loops that depend on Tensors so that they run dynamically in the graph.
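For example, an if statement whose predicate is a Tensor is rewritten into a tf.cond for you. A minimal sketch, assuming TF 2.x (abs_like is our own illustrative name):

@tf.function
def abs_like(x):
  if x > 0:  # Tensor predicate, so AutoGraph lowers this branch to tf.cond
    return x
  else:
    return -x

abs_like(tf.constant(-3.0))  # 3.0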

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([10]))

[0.690678835 0.687305927 0.280717611 ... 0.481444716 0.331221104 0.0514520407]
[0.598417938 0.596248507 0.273569077 ... 0.447399884 0.31961754 0.0514066778]
[0.535922825 0.534374714 0.26694271 ... 0.419759363 0.309161037 0.0513614379]
...
[0.117480896 0.117464855 0.109844379 ... 0.11577794 0.112272345 0.0472578816]
(one line is printed per iteration until tf.reduce_sum(x) <= 1; roughly a hundred intermediate lines omitted)

<tf.Tensor: id=1841, shape=(10,), dtype=float32, numpy=
array([0.11694337, 0.11692756, 0.10940471, 0.11663772, 0.04113298,
       0.11698803, 0.10682372, 0.11526338, 0.11180297, 0.04722273],
      dtype=float32)>

# If you're curious, you can inspect the code AutoGraph generates,
# though it feels a bit like reading assembly language.

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))

from __future__ import print_function

def tf__f(x):
  try:
    with ag__.function_scope('f'):
      do_return = False
      retval_ = None

      def loop_test(x_1):
        with ag__.function_scope('loop_test'):
          return ag__.gt(ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x_1,), {}), 1)

      def loop_body(x_1):
        with ag__.function_scope('loop_body'):
          with ag__.utils.control_dependency_on_returns(ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x_1,), {})):
            tf_1, x = ag__.utils.alias_tensors(tf, x_1)
            x = ag__.converted_call('tanh', tf_1, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x,), {})
            return x,

      x, = ag__.while_stmt(loop_test, loop_body, (x,), (tf, x, ag__))
      do_return = True
      retval_ = x
      return retval_
  except:
    ag__.rewrite_graph_construction_error(ag_source_map__)

tf__f.autograph_info__ = {}


To control AutoGraph, remember that it only affects the basic control flow constructs in Python (if, for, while, break, and so on), and that it only changes them if the predicate is a Tensor.


So in the example below the first loop is statically unrolled, while the second loop is dynamically converted:

@tf.function
def f(x):
  for i in range(10):     # a static Python loop; AutoGraph won't convert it
    do_stuff()
  for i in tf.range(10):  # depends on a Tensor; AutoGraph will convert it
    do_stuff()


Similarly, to guarantee that prints and asserts happen dynamically, use tf.print and tf.Assert:

@tf.function
def f(x):
  for i in tf.range(10):
    tf.print(i)
    tf.Assert(i < 10, ["a"])
    x += x
  return x

f(10)

0
1
2
3
4
5
6
7
8
9

<tf.Tensor: id=1904, shape=(), dtype=int32, numpy=10240>
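Unlike a Python assert, which would only run while the function is being traced, tf.Assert becomes a graph op that checks its predicate on every call. A small sketch, assuming TF 2.x (checked_sqrt is our own name):

@tf.function
def checked_sqrt(x):
  tf.Assert(x >= 0.0, ["x must be non-negative, got:", x])
  return tf.sqrt(x)

checked_sqrt(tf.constant(4.0))     # 2.0
# checked_sqrt(tf.constant(-1.0))  # raises InvalidArgumentError at run time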


Finally, AutoGraph cannot compile arbitrary Python code into TensorFlow graphs. In particular, the data structures you use dynamically still need to be TensorFlow data structures.


So, for example, the best way to accumulate data in a loop is still to use tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(tf.float32, size=10)
  for i in tf.range(10):
    x += x
    ta = ta.write(i, x)
  return ta.stack()

f(10.0)

<tf.Tensor: id=1973, shape=(10,), dtype=float32, numpy=
array([   20.,    40.,    80.,   160.,   320.,   640.,  1280.,  2560.,
        5120., 10240.], dtype=float32)>
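tf.TensorArray can also grow dynamically when the number of iterations is not known up front; a small sketch, assuming TF 2.x (running_sum is our own name):

@tf.function
def running_sum(xs):
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  total = tf.constant(0.0)
  for i in tf.range(tf.shape(xs)[0]):
    total += xs[i]
    ta = ta.write(i, total)  # write returns the updated array; keep the result
  return ta.stack()

running_sum(tf.constant([1.0, 2.0, 3.0]))  # [1., 3., 6.]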



Next steps

Now revisit your earlier notebooks and try using tf.function to speed up your code!


